import os
import wandb
from dotenv import load_dotenv
from accelerate import Accelerator
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import numpy as np

load_dotenv(verbose=True)

from model.utils.misc import is_main_process
from model.utils.singleton import Singleton

# Public API of this module: the logger class and its module-level instance.
__all__ = ['WandbLogger', 'wandb_logger']

class WandbLogger(metaclass=Singleton):
    """Thin wrapper around the module-level ``wandb`` API that only logs from
    the main process.

    Every public method checks ``self.is_main_process`` before calling into
    ``wandb``, so non-main processes turn logging into no-ops. Construction
    goes through the project's ``Singleton`` metaclass (presumably so all
    importers share one instance — confirm against ``model.utils.singleton``).
    """

    def __init__(self):
        # Until init_logger() determines the process role, assume this is the
        # main process, so calls are forwarded to wandb.
        self.is_main_process = True

    def init_logger(self, project, name, config, dir, accelerator: Accelerator = None):
        """Log in to W&B and start a run, on the main process only.

        Args:
            project: W&B project name passed to ``wandb.init``.
            name: Run name passed to ``wandb.init``.
            config: Run configuration passed to ``wandb.init``.
            dir: Local directory passed to ``wandb.init``.
            accelerator: Optional ``accelerate.Accelerator``. When given, its
                global ``is_main_process`` flag decides which process logs.
                Otherwise the project helper ``is_main_process()`` is used.
        """
        if accelerator is None:
            self.is_main_process = is_main_process()
        else:
            # Use the *global* main process. ``is_local_main_process`` is true
            # once per node, which made every node start its own W&B run on
            # multi-node jobs.
            self.is_main_process = accelerator.is_main_process

        if self.is_main_process:
            # The key comes from the environment, which load_dotenv() may have
            # populated from a .env file when this module was imported.
            wandb.login(key=os.getenv("WANDB_API_KEY"))
            wandb.init(project=project, name=name, config=config, dir=dir)

    def log(self, log_dict):
        """Forward ``log_dict`` to ``wandb.log`` on the main process only."""
        if self.is_main_process:
            wandb.log(log_dict)

    def finish(self):
        """Call ``wandb.finish()`` on the main process only."""
        if self.is_main_process:
            wandb.finish()

    def load_image(self, key, plt_figure, dpi=100):
        """Log a matplotlib figure to W&B as an image under ``key``.

        Despite its name, this method *logs* the figure; it loads nothing.
        On the main process the figure is rendered to an in-memory PNG,
        converted to a pixel array and logged as ``wandb.Image``. The figure
        is closed on every process, including when rendering fails, so
        figures do not accumulate on non-main processes.

        Args:
            key: Name the image is logged under.
            plt_figure: The matplotlib figure to log. This call closes it.
            dpi: Resolution used to rasterize the figure (default 100).
        """
        try:
            if not self.is_main_process:
                return
            with BytesIO() as buf:
                plt_figure.savefig(buf, format="png", dpi=dpi)
                buf.seek(0)
                # Copy the pixels out before the buffer and image are closed.
                with Image.open(buf) as img:
                    pixels = np.array(img)
            # wandb.log does not treat a bare ndarray as an image; wrap it so
            # it is rendered as media in the W&B UI.
            wandb.log({key: wandb.Image(pixels)})
        finally:
            plt.close(plt_figure)

# Shared module-level logger; import this rather than constructing WandbLogger
# directly. Construction goes through the Singleton metaclass.
wandb_logger: WandbLogger = WandbLogger()